[firred] bring fired aed to wenet #2680

Merged
merged 12 commits into from
Feb 8, 2025

Conversation

Mddct
Collaborator

@Mddct Mddct commented Feb 5, 2025

https://github.com/FireRedTeam/FireRedASR/tree/main

  • convert script
  • precision

FireRedASR-LLM has not been released yet. Once it is released, I will check how the AED encoder differs from the LLM encoder and then convert it as well.
FireRedTeam/FireRedASR#5 (comment)

@Mddct Mddct marked this pull request as draft February 5, 2025 12:35
@Mddct Mddct force-pushed the Mddct-firedasr branch 2 times, most recently from 2428b08 to aebba69 Compare February 6, 2025 12:28
@Mddct
Collaborator Author

Mddct commented Feb 7, 2025

Transcription works!

from argparse import Namespace

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import yaml
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer

config = 'wenet_fired/train.yaml'
args = Namespace(checkpoint='wenet_fired/wenet_firered.pt')

with open(config, 'r') as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)

model, _ = init_model(args, configs)

model.eval()
audio_file = '../../test/resources/aishell-BAC009S0724W0121.wav'
waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
waveform = waveform.to(torch.float)
feats = kaldi.fbank(waveform,
                    num_mel_bins=80,
                    frame_length=25,
                    frame_shift=10,
                    energy_floor=0.0,
                    sample_frequency=16000)

feats = feats.unsqueeze(0)
feats_lens = torch.tensor([feats.shape[1]], dtype=torch.int64)
results = model.decode(methods=['attention'],
                       speech=feats,
                       speech_lengths=feats_lens,
                       beam_size=10)

tokenizer = init_tokenizer(configs)

for mode, hyps in results.items():
    print('{} {}'.format(mode, tokenizer.detokenize(hyps[0].tokens)))
Screenshot 2025-02-07 14 24 59
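
As a rough sanity check on `feats_lens` in the snippet above, the number of fbank frames can be estimated from the waveform length. This is a minimal sketch that assumes torchaudio's Kaldi-compatible default `snip_edges=True`, under which only windows that fit entirely inside the signal are kept. `num_fbank_frames` is an illustrative helper name, not part of wenet or torchaudio.

```python
def num_fbank_frames(num_samples: int,
                     sample_rate: int = 16000,
                     frame_length_ms: float = 25.0,
                     frame_shift_ms: float = 10.0) -> int:
    """Estimate the frame count for Kaldi-style framing (snip_edges=True assumed)."""
    window_size = int(sample_rate * frame_length_ms / 1000)   # 400 samples at 16 kHz
    window_shift = int(sample_rate * frame_shift_ms / 1000)   # 160 samples at 16 kHz
    if num_samples < window_size:
        # Not even one full window fits, so no frame is produced.
        return 0
    return 1 + (num_samples - window_size) // window_shift


# One second of 16 kHz audio: 1 + (16000 - 400) // 160 = 98 frames.
print(num_fbank_frames(16000))
```

If the estimate disagrees with `feats.shape[1]` for a given file, the framing options (or the sample rate of the file) are the first thing to check.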

@Mddct Mddct marked this pull request as ready for review February 7, 2025 08:43
@Mddct Mddct changed the title from "[firred] bring fired aed to wenet h" to "[firred] bring fired aed to wenet" Feb 7, 2025
@Mddct
Collaborator Author

Mddct commented Feb 7, 2025

encoder part

from fireredasr.models.fireredasr import FireRedAsr

wav_paths = ['../../test/resources/aishell-BAC009S0724W0121.wav']

red_model = FireRedAsr.from_pretrained("aed", "pretrained_models/FireRedASR-AED-L")
red_feats, lengths, durs = red_model.feat_extractor(wav_paths)
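# Run the original FireRedASR encoder (note: red_model, not the wenet `model`).
red_enc_outputs, _, enc_mask = red_model.model.encoder(red_feats, lengths)

# Run the converted wenet encoder on the features computed earlier.
wenet_encoder = model.encoder
wenet_enc_outs, _ = wenet_encoder(feats, feats_lens)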

torch.allclose(wenet_enc_outs, red_enc_outputs) # True
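
`torch.allclose` with its default tolerances (`rtol=1e-05`, `atol=1e-08`) only returns a boolean. For the "precision" item it can help to also see the largest absolute difference. The following is a dependency-free sketch of the same element-wise criterion, `|a - b| <= atol + rtol * |b|`; `compare_outputs` is a hypothetical helper, not something in wenet or FireRedASR.

```python
from typing import Iterable, Tuple


def compare_outputs(a: Iterable[float],
                    b: Iterable[float],
                    rtol: float = 1e-5,
                    atol: float = 1e-8) -> Tuple[bool, float]:
    """Return (all_close, max_abs_diff) for two equally long flat sequences."""
    a, b = list(a), list(b)
    if len(a) != len(b):
        raise ValueError("inputs must have the same number of elements")
    max_abs_diff = 0.0
    all_close = True
    for x, y in zip(a, b):
        diff = abs(x - y)
        max_abs_diff = max(max_abs_diff, diff)
        # Element-wise criterion documented for torch.allclose.
        if diff > atol + rtol * abs(y):
            all_close = False
    return all_close, max_abs_diff


ok, worst = compare_outputs([1.0, 2.0, 3.0], [1.0, 2.0, 3.000001])
print(ok, worst)
```

With the tensors above, the inputs would be `wenet_enc_outs.flatten().tolist()` and `red_enc_outputs.flatten().tolist()`, assuming both have the same shape.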

decoder part

The decoder is a standard transformer decoder, so only the AISHELL decoding results are posted here; a layer-by-layer comparison is not done.

  • aishell
ec: 虽 然 知 道 惠 若 琪 的 心 脏 不 太 好 

===========================================================================

Overall -> 0.55 % N=104765 C=104230 S=468 D=67 I=46
Mandarin -> 0.55 % N=104762 C=104230 S=465 D=67 I=46
English -> 0.00 % N=0 C=0 S=0 D=0 I=0
Other -> 100.00 % N=3 C=0 S=3 D=0 I=0

===========================================================================
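
The percentages in the report follow from the counts on each row: the error rate is (S + D + I) / N, and C = N - S - D. The sketch below recomputes them from the numbers above. Returning 0 for an empty reference set is an assumption made here to match the `English` row, and `error_rate` is an illustrative helper name.

```python
def error_rate(n: int, s: int, d: int, i: int) -> float:
    """Error rate in percent: (substitutions + deletions + insertions) / N * 100."""
    if n == 0:
        # Assumption: report 0.00 % when there are no reference tokens.
        return 0.0
    return 100.0 * (s + d + i) / n


# Counts copied from the report above.
rows = {
    "Overall": dict(n=104765, s=468, d=67, i=46),
    "Mandarin": dict(n=104762, s=465, d=67, i=46),
    "English": dict(n=0, s=0, d=0, i=0),
    "Other": dict(n=3, s=3, d=0, i=0),
}
for name, counts in rows.items():
    print(f"{name} -> {error_rate(**counts):.2f} %")
```

For the overall row this gives (468 + 67 + 46) / 104765, which rounds to 0.55 %.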


@robin1001 robin1001 merged commit de41dd7 into main Feb 8, 2025
6 checks passed
@robin1001 robin1001 deleted the Mddct-firedasr branch February 8, 2025 01:55